import csv
import random
import numpy as np

def read_cellnumber_table(cellnumber_table_file):
    """
    Parse the per-division cell-number schedule from a CSV file.

    Columns are looked up by header name, so their order in the file does
    not matter and any extra columns are ignored.

    Args:
    - cellnumber_table_file: Path of the CSV file; its header must include
      division, desired_population, q, r and s.

    Returns:
    - A list with one (division, desired_population, q, r, s) tuple of ints
      per data row, in file order.
    """
    field_names = ('division', 'desired_population', 'q', 'r', 's')
    with open(cellnumber_table_file, 'r') as handle:
        return [
            tuple(int(record[name]) for name in field_names)
            for record in csv.DictReader(handle)
        ]

def read_lookup_table(lookup_table_file):
    """
    Load the per-position lookup table used to seed the initial mother cell.

    The first line of the CSV is treated as a header and skipped (an empty
    file therefore raises StopIteration). Every following row is read
    positionally: columns 0 and 1 as ints, columns 2 and 3 as floats.

    Args:
    - lookup_table_file: Path of the lookup-table CSV file.

    Returns:
    - A 4-tuple of lists (lineage, array1, p0_1, p1_0), one entry per data row.
    """
    converters = (int, int, float, float)
    columns = ([], [], [], [])
    with open(lookup_table_file, 'r') as handle:
        records = csv.reader(handle)
        next(records)  # the header row only carries column names
        for record in records:
            for position, convert in enumerate(converters):
                columns[position].append(convert(record[position]))
    return columns

def shuffle_qrs_pool(q, r, s):
    """
    Build the pool of fate markers for one division and shuffle it.

    Args:
    - q: How many 'q' markers to include (mother yields one daughter).
    - r: How many 'r' markers to include (mother yields two daughters).
    - s: How many 's' markers to include (mother yields no daughters).

    Returns:
    - A list of q + r + s markers, randomly permuted with random.shuffle,
      so the order follows the global `random` seed.
    """
    pool = []
    for marker, count in (('q', q), ('r', r), ('s', s)):
        pool.extend([marker] * count)
    random.shuffle(pool)
    return pool

def _mutate_array1(array1, p0_1, p1_0):
    """Return a copy of array1 with bits flipped by per-position probabilities."""
    mutated = np.copy(array1)
    # One uniform draw per position serves both directions; the 0- and
    # 1-positions are disjoint, so the draws never interact.
    random_numbers = np.random.rand(len(mutated))

    # Flip 0 -> 2 with probability p0_1; 2 is a temporary marker so these
    # bits are not caught by the 1 -> 0 step below.
    mutated[np.logical_and(mutated == 0, random_numbers < p0_1)] = 2

    # Flip 1 -> 0 with probability p1_0 (only bits that were 1 originally).
    mutated[np.logical_and(mutated == 1, random_numbers < p1_0)] = 0

    # Resolve the temporary marker: every 2 becomes 1.
    mutated[mutated == 2] = 1
    return mutated


def _make_daughter(lineage_array, array1, p0_1, p1_0, division_number, label):
    """Build one daughter tuple: lineage copy tagged with label, mutated Array1, copied probabilities."""
    daughter_lineage_array = np.copy(lineage_array)
    daughter_lineage_array[division_number] = label
    daughter_array1 = _mutate_array1(array1, p0_1, p1_0)
    return (daughter_lineage_array, daughter_array1, np.copy(p0_1), np.copy(p1_0))


def cell_division(mother_cells, qrs_pool, cellnumber_table, division):
    """
    Simulate one round of cell division.

    Each mother cell is paired, in order, with the next marker of qrs_pool:
    - 'q': one daughter, whose lineage array gets 1 at the division index.
    - 'r': two daughters, with 1 and 2 at the division index respectively.
    - 's' (or any other marker): no daughters; the mother is lost.
    The division index is cellnumber_table[division][0]. Every daughter gets
    an independently mutated copy of the mother's Array1 and its own copies
    of the lineage and probability arrays; the mother's arrays are not
    modified.

    NOTE(review): `division` is used as a row index into cellnumber_table,
    but main() passes the row's division *value*. These agree only when
    divisions are numbered 0..N-1 in table order — confirm the table layout.

    Args:
    - mother_cells: List of (lineage_array, array1, p0_1, p1_0) tuples.
    - qrs_pool: List of markers. The first len(mother_cells) entries are
      consumed (removed from the list), one per mother cell.
    - cellnumber_table: The table containing cell number information.
    - division: Row index into cellnumber_table for the current division.

    Returns:
    - List of daughter cells as (lineage_array, array1, p0_1, p1_0) tuples.

    Raises:
    - IndexError: If qrs_pool holds fewer markers than there are mother cells.
    """
    num_mothers = len(mother_cells)
    if len(qrs_pool) < num_mothers:
        raise IndexError(
            f"qrs_pool has {len(qrs_pool)} markers but there are "
            f"{num_mothers} mother cells"
        )

    daughter_cells = []
    for mother_cell, marker in zip(mother_cells, qrs_pool):
        if marker not in ('q', 'r'):
            continue  # 's': this mother leaves no daughters
        lineage_array, array1, p0_1, p1_0 = mother_cell[0], mother_cell[1], mother_cell[2], mother_cell[3]
        division_number = cellnumber_table[division][0]
        daughter_cells.append(
            _make_daughter(lineage_array, array1, p0_1, p1_0, division_number, 1)
        )
        if marker == 'r':
            daughter_cells.append(
                _make_daughter(lineage_array, array1, p0_1, p1_0, division_number, 2)
            )

    # Remove the used markers in one O(n) step. Per-cell pop(0) would be
    # O(n) each (O(n^2) per division), but callers still see the pool consumed.
    del qrs_pool[:num_mothers]
    return daughter_cells

def calculate_and_write_average_bit_values(rows, run, output_file='sumlineage_arrays.csv'):
    """
    Append per-position Array1 averages and their variance to a CSV file.

    The header row (Run, bit1..bitN, Variance) is written only when the
    output file is empty. Repeated calls, such as several summaries within
    the same run, therefore append data rows without duplicating the header.

    Args:
    - rows: Sequence of rows whose column 2 holds a comma-separated Array1
      string (for example '0,1,1'), as built by main().
    - run: Run number; the data row is labelled f'Run{run}'.
    - output_file: CSV file to append to (default 'sumlineage_arrays.csv').

    Raises:
    - ValueError: If rows is empty, since there is nothing to average.
    """
    if not rows:
        raise ValueError("rows must contain at least one cell to average")

    num_bits = len(rows[0][2].split(','))  # Array1 is in column 2
    num_cells = len(rows)

    # Sum the bit values at each position of Array1 across all cells.
    array1_bit_sums = [0] * num_bits
    for row in rows:
        for i, bit in enumerate(parse_bit_array(row[2])):
            array1_bit_sums[i] += bit

    average_array1_bit_values = [bit_sum / num_cells for bit_sum in array1_bit_sums]

    # Population variance (NumPy default, ddof=0) of the per-position averages.
    variance = np.var(average_array1_bit_values)

    with open(output_file, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        # In append mode the stream starts at end-of-file, so position 0
        # means the file is new or empty and still needs its header.
        if csvfile.tell() == 0:
            writer.writerow(['Run'] + [f'bit{i+1}' for i in range(num_bits)] + ['Variance'])
        writer.writerow([f'Run{run}'] + [str(bit_value) for bit_value in average_array1_bit_values] + [str(variance)])

def main():
    """
    Run the cell-division simulation end to end.

    Reads 'cellnumber_table.csv' (one row per division) and
    'lookup5_table_Xonly.csv' (initial per-position arrays). Each run then
    starts from a single mother cell and applies every division in table
    order. Every write_frequency divisions, and after the last division, the
    current population's Array1 averages are appended through
    calculate_and_write_average_bit_values. The final population of the last
    run is written to 'last_run_lineage_arrays.csv'.
    """
    num_runs = 1  # Specify the number of runs
    write_frequency = 50  # Write summary statistics every 50 divisions (and after the last one)

    # The input tables are the same for every run, so read them only once.
    cellnumber_table_file = 'cellnumber_table.csv'
    cellnumber_table = read_cellnumber_table(cellnumber_table_file)
    num_divisions = len(cellnumber_table)

    lookup_table_file = 'lookup5_table_Xonly.csv'
    lineage_values, array1_values, p0_1_values, p1_0_values = read_lookup_table(lookup_table_file)

    # Stays empty (header-only output file) if num_runs is 0.
    last_run_lineage_arrays = []

    for run in range(1, num_runs + 1):
        print(f"Start of Run {run}")

        # Start with a single mother cell built from fresh arrays.
        mother_cells = [(
            np.array(lineage_values, dtype=int),
            np.array(array1_values, dtype=int),
            np.array(p0_1_values, dtype=float),
            np.array(p1_0_values, dtype=float),
        )]

        for division_counter, (division, _desired_population, q, r, s) in enumerate(cellnumber_table, start=1):
            print(f"Division {division}, Number of mother cells: {len(mother_cells)}")

            qrs_pool = shuffle_qrs_pool(q, r, s)
            # The daughters become the mother cells of the next division.
            mother_cells = cell_division(mother_cells, qrs_pool, cellnumber_table, division)

            # Use the row position rather than the division value, so the
            # check is not fooled by repeated division numbers.
            is_last_division = division_counter == num_divisions
            if division_counter % write_frequency == 0 or is_last_division:
                if mother_cells:
                    # Serialise each cell as in the output CSV; the summary
                    # function parses column 2 (Array1) back into bits.
                    rows = [
                        [f'Daughter Cell {i+1}', ','.join(map(str, lineage)), ','.join(map(str, array1))]
                        for i, (lineage, array1, _, _) in enumerate(mother_cells)
                    ]
                    calculate_and_write_average_bit_values(rows, run)
                else:
                    print(f"Division {division}: no cells left, skipping summary")

        # The last assignment wins, leaving the final run's population.
        last_run_lineage_arrays = mother_cells
        print(f"End of Run {run}")

    # Write lineage arrays of the last run to a CSV file
    with open('last_run_lineage_arrays.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Daughter Cell', 'Lineage Array', 'Array1'])
        for i, (lineage, array1, _, _) in enumerate(last_run_lineage_arrays):
            writer.writerow([f'Daughter Cell {i+1}', ','.join(map(str, lineage)), ','.join(map(str, array1))])

def parse_bit_array(array_string):
    """
    Convert a comma-separated string of integers into a list of ints.

    Example: '1,0,1' -> [1, 0, 1]. Raises ValueError if any field is not
    an integer literal (including an empty string).
    """
    return list(map(int, array_string.split(',')))

# Run the simulation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
